"""
Phase 5: Unified Coefficient Surface
======================================
Synthesize Phase 2-4 results into a "coefficient surface" showing how
Z₁'s effect on each DV varies across institutional regimes.

Produces:
1. Implied Z₁ effect for each country archetype
2. Heat map data (archetype × DV matrix)
3. Key summary for the paper narrative
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory: main() reads "unified_panel.csv" from here.
DATA = PROJECT_DIR / "data" / "processed"
# Output directory: main() writes the two phase5_*.md tables here.
OUT_TABLES = PROJECT_DIR / "output" / "tables"

# Base regressors included in every model. Z_1 is the variable whose implied
# effect is reported (the output tables describe it as demographic PC₁).
Z_VARS = ['Z_1', 'Z_2', 'Z_3']
# Dependent variables; one model is estimated per entry that exists in the
# panel, and each becomes a column of the coefficient-surface table.
DVS = ['ca_gdp', 'gross_savings_gdp', 'gross_investment_gdp', 'nfa_gdp']

# Country archetypes defined by moderator combinations.
# Each archetype maps moderator name -> 0/1 flag. For every flag set to 1,
# main() adds the matching Z_1 interaction coefficient to the base Z_1
# coefficient. An archetype with all flags 0 ('Middle-income, closed')
# therefore gets exactly the base Z_1 coefficient.
# Dict order is the row order of the written table.
ARCHETYPES = {
    'Low-income, closed': {'income_low': 1, 'income_high': 0, 'is_oecd': 0,
                           'safe_issuer': 0, 'eurozone': 0, 'kaopen_saturated': 0},
    'Middle-income, closed': {'income_low': 0, 'income_high': 0, 'is_oecd': 0,
                               'safe_issuer': 0, 'eurozone': 0, 'kaopen_saturated': 0},
    'Middle-income, open': {'income_low': 0, 'income_high': 0, 'is_oecd': 0,
                             'safe_issuer': 0, 'eurozone': 0, 'kaopen_saturated': 1},
    'High-income, non-OECD': {'income_low': 0, 'income_high': 1, 'is_oecd': 0,
                               'safe_issuer': 0, 'eurozone': 0, 'kaopen_saturated': 0},
    'OECD, non-safe': {'income_low': 0, 'income_high': 1, 'is_oecd': 1,
                        'safe_issuer': 0, 'eurozone': 0, 'kaopen_saturated': 1},
    'OECD, EMU': {'income_low': 0, 'income_high': 1, 'is_oecd': 1,
                   'safe_issuer': 0, 'eurozone': 1, 'kaopen_saturated': 1},
    'Safe issuer': {'income_low': 0, 'income_high': 1, 'is_oecd': 1,
                     'safe_issuer': 1, 'eurozone': 0, 'kaopen_saturated': 1},
}

# Moderator -> interaction suffix.
# Interaction columns in the panel are looked up as f'{z}_x_{suffix}',
# e.g. 'Z_1_x_oecd'; columns that are absent from the panel are skipped.
MOD_SUFFIX = {
    'income_low': 'low', 'income_high': 'high', 'is_oecd': 'oecd',
    'safe_issuer': 'safe', 'eurozone': 'emu', 'kaopen_saturated': 'ksat',
}


def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1, and an
    empty string otherwise (including when no comparison holds, e.g. NaN).
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def run_gls(df, y_var, x_vars):
    """Fit a PanelGLS model of ``y_var`` on ``x_vars`` and flatten the result.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped before estimation.

    Args:
        df: DataFrame holding ``y_var``, every name in ``x_vars``, and the
            ``iso3`` and ``year`` columns.
        y_var: Name of the dependent-variable column.
        x_vars: Sequence of regressor column names (list, tuple, ...).

    Returns:
        ``None`` when fewer than 50 complete rows remain or when the fit
        raises. Otherwise a dict with ``n_obs``, ``r_squared`` and, for each
        regressor ``v``, the keys ``{v}_coef``, ``{v}_se`` and ``{v}_p``.
    """
    # Normalise to a list so tuples/other sequences work both for the column
    # concatenation and for DataFrame column selection.
    x_vars = list(x_vars)
    sub = df[[y_var, *x_vars, 'iso3', 'year']].dropna()
    if len(sub) < 50:
        return None

    gls = PanelGLS()
    try:
        gls.fit(sub[y_var].values, sub[x_vars].values,
                sub['iso3'].values, sub['year'].values)
    except Exception as exc:
        # Still best-effort (the caller skips this DV), but report why the
        # fit failed instead of discarding the reason silently.
        print(f"  [run_gls] {y_var}: PanelGLS fit failed "
              f"({type(exc).__name__}: {exc})")
        return None

    result = {'n_obs': gls.n_obs, 'r_squared': gls.r_squared}
    # beta / se / pvalues are indexed in the same order as x_vars.
    for i, var in enumerate(x_vars):
        result[f'{var}_coef'] = gls.beta[i]
        result[f'{var}_se'] = gls.se[i]
        result[f'{var}_p'] = gls.pvalues[i]
    return result


def _interaction_columns(df):
    """Return the Z×moderator interaction columns that exist in ``df``.

    Ordered moderator-major (in ``MOD_SUFFIX`` order), then by ``Z_VARS``.
    """
    return [
        f'{z}_x_{suffix}'
        for suffix in MOD_SUFFIX.values()
        for z in Z_VARS
        if f'{z}_x_{suffix}' in df.columns
    ]


def _implied_z1_effects(m):
    """Return ``{archetype: implied Z₁ effect}`` from a ``run_gls`` result.

    The implied effect is the base ``Z_1`` coefficient plus the ``Z_1``
    interaction coefficient of every moderator the archetype sets to 1.
    Interactions that were not part of the model contribute nothing.
    """
    effects = {}
    for arch_name, arch_mods in ARCHETYPES.items():
        implied = m['Z_1_coef']
        for mod, val in arch_mods.items():
            if val != 1 or mod not in MOD_SUFFIX:
                continue
            key = f'Z_1_x_{MOD_SUFFIX[mod]}_coef'
            if key in m:
                implied += m[key]
        effects[arch_name] = implied
    return effects


def _write_surface_table(surface, path):
    """Write the archetype × DV matrix of implied Z₁ effects as Markdown."""
    dv_labels = {'ca_gdp': 'CA/GDP', 'gross_savings_gdp': 'Savings/GDP',
                 'gross_investment_gdp': 'Investment/GDP', 'nfa_gdp': 'NFA/GDP'}
    dvs_present = [dv for dv in DVS if dv in surface]

    # Explicit UTF-8: the text contains non-ASCII characters (Z₁, →, PC₁)
    # that a non-UTF-8 locale default encoding cannot represent.
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Phase 5: Implied Z₁ Effect by Country Archetype\n\n")
        f.write("Estimated from a single PanelGLS model with all moderator interactions.\n\n")

        header = "| Archetype | " + " | ".join(dv_labels.get(d, d) for d in dvs_present) + " |"
        f.write(header + "\n")
        f.write("|---" * (1 + len(dvs_present)) + "|\n")

        for arch_name in ARCHETYPES:
            vals = [f"{surface[dv].get(arch_name, np.nan):.1f}" for dv in dvs_present]
            f.write(f"| {arch_name} | " + " | ".join(vals) + " |\n")

        f.write("\n*Values show implied Z₁ coefficient (effect of one-unit increase "
                "in demographic PC₁ on DV) for each archetype.*\n")
        f.write("*Positive CA = aging → surplus; Positive Savings = aging → more savings.*\n")
        f.write("*Estimated via PanelGLS with AR(1) correction.*\n")


def _write_key_findings(surface, path):
    """Write the range/spread summary for CA and savings as Markdown."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Phase 5: Key Findings — Coefficient Surface\n\n")

        if 'ca_gdp' in surface:
            ca_vals = surface['ca_gdp']
            min_arch = min(ca_vals, key=ca_vals.get)
            max_arch = max(ca_vals, key=ca_vals.get)
            lo, hi = ca_vals[min_arch], ca_vals[max_arch]
            f.write("## Current Account\n")
            f.write(f"- **Range**: {lo:.1f} ({min_arch}) to "
                    f"{hi:.1f} ({max_arch})\n")
            f.write(f"- **Spread**: {hi - lo:.1f} pp/GDP\n")
            # Only claim a sign reversal when the estimates actually straddle
            # zero; previously the sentence was written unconditionally.
            if lo < 0 < hi:
                f.write("- The sign of demography's effect on the CA **reverses** across "
                        "institutional regimes.\n\n")
            else:
                f.write("- The sign of demography's effect on the CA does **not** reverse "
                        "across these archetypes.\n\n")

        if 'gross_savings_gdp' in surface:
            sv = surface['gross_savings_gdp']
            min_a = min(sv, key=sv.get)
            max_a = max(sv, key=sv.get)
            f.write("## Savings\n")
            f.write(f"- **Range**: {sv[min_a]:.1f} ({min_a}) to "
                    f"{sv[max_a]:.1f} ({max_a})\n")
            f.write(f"- **Spread**: {sv[max_a] - sv[min_a]:.1f} pp/GDP\n\n")

        f.write("## Implication\n")
        f.write("A single Z₁ coefficient is meaningless without specifying the institutional "
                "context. The 'average' demographic effect masks regime-specific effects that "
                "differ not just in magnitude but in sign.\n")


def main():
    """Estimate the full interaction model per DV and write the Phase 5 tables.

    Reads ``unified_panel.csv`` from ``DATA``, regresses each dependent
    variable in ``DVS`` on ``Z_VARS`` plus every available Z×moderator
    interaction, derives the implied Z₁ effect for each archetype, prints a
    console report, and writes ``phase5_coefficient_surface.md`` and
    ``phase5_key_findings.md`` into ``OUT_TABLES``.
    """
    print("Phase 5: Unified Coefficient Surface")
    print("=" * 70)

    df = pd.read_csv(DATA / "unified_panel.csv")

    # "Kitchen sink" model: all moderator interactions simultaneously.
    x_full = Z_VARS + _interaction_columns(df)
    print(f"Full model: {len(x_full)} regressors\n")

    surface = {}  # {dv: {archetype: implied_z1_effect}}

    for dv in DVS:
        if dv not in df.columns:
            continue

        m = run_gls(df, dv, x_full)
        if m is None:
            print(f"  {dv}: estimation failed")
            continue

        print(f"\n{dv}: N={m['n_obs']}, R²={m['r_squared']:.3f}")
        print(f"  Z₁ base: {m['Z_1_coef']:.3f}{stars(m['Z_1_p'])}")

        # Report every Z₁ interaction that made it into the model.
        for mod, suffix in MOD_SUFFIX.items():
            key = f'Z_1_x_{suffix}'
            if f'{key}_coef' in m:
                c, p = m[f'{key}_coef'], m[f'{key}_p']
                print(f"  Z₁×{mod}: {c:.3f}{stars(p)} (p={p:.3f})")

        surface[dv] = _implied_z1_effects(m)
        for arch_name, implied in surface[dv].items():
            print(f"    {arch_name}: Z₁ effect = {implied:.1f}")

    # ── Write output tables ───────────────────────────────────────────
    print(f"\n{'=' * 70}")
    print("Writing output tables...")

    # Create the output directory on a fresh checkout instead of failing.
    OUT_TABLES.mkdir(parents=True, exist_ok=True)

    _write_surface_table(surface, OUT_TABLES / "phase5_coefficient_surface.md")
    print("  Wrote: phase5_coefficient_surface.md")

    _write_key_findings(surface, OUT_TABLES / "phase5_key_findings.md")
    print("  Wrote: phase5_key_findings.md")

    print("\nPhase 5 complete.")


if __name__ == '__main__':
    main()
